"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import copy
import hashlib
import math
import random
import threading
import time
from collections import deque
from typing import List

import orjson
import redis

from fastdeploy.engine.request import (CompletionOutput, Request,
                                       RequestMetrics, RequestOutput)
from fastdeploy.utils import scheduler_logger as logger


class SplitWiseSchedulerConfig(object):
    """Plain settings holder for the SplitWise scheduler components."""

    def __init__(
            self,
            nodeid=None,
            host="127.0.0.1",
            port=6379,
            password=None,
            topic="fd",
            ttl=900,
            release_load_expire_period=600,
            sync_period=5,
            expire_period=3000,
            clear_expired_nodes_period=60,
            reader_parallel=4,
            reader_batch_size=200,
            writer_parallel=4,
            writer_batch_size=200,
            **kwargs):
        """
        Store the scheduler settings as attributes.

        Args:
            nodeid: node identifier; a random UUID4 string is generated
                when omitted.
            host / port / password: redis connection settings.
            topic: redis topic.
            ttl: stored as-is.
            release_load_expire_period: period in seconds, stored as-is.
            sync_period: period in milliseconds, stored as-is.
            expire_period: period in milliseconds; stored in seconds.
            clear_expired_nodes_period: period in seconds, stored as-is.
            reader_parallel / reader_batch_size: result reader settings.
            writer_parallel / writer_batch_size: result writer settings.
            **kwargs: accepted and ignored.
        """
        # NOTE: the assignment order below is the order in which print()
        # lists the settings, since it walks the instance dictionary.
        if nodeid is not None:
            self.nodeid = nodeid
        else:
            import uuid
            self.nodeid = str(uuid.uuid4())

        # redis connection
        self.redis_host = host
        self.redis_port = port
        self.redis_password = password
        self.redis_topic = topic

        self.ttl = ttl
        self.release_load_expire_period = release_load_expire_period

        self.sync_period = sync_period
        # milliseconds in, seconds stored
        self.expire_period = expire_period / 1000.
        self.clear_expired_nodes_period = clear_expired_nodes_period

        # reader / writer sizing
        self.reader_parallel = reader_parallel
        self.reader_batch_size = reader_batch_size
        self.writer_parallel = writer_parallel
        self.writer_batch_size = writer_batch_size

    def check(self):
        """Validate the arguments (currently a no-op)."""
        return None

    def print(self):
        """
        Log every stored setting, one per line, followed by a separator.
        """
        logger.info("LocalScheduler Configuration Information :")
        settings = self.__dict__
        for name in settings:
            logger.info("{:<20}:{:<6}{}".format(name, "", settings[name]))
        logger.info(
            "=============================================================")


class SplitWiseScheduler(object):
    """
       SplitWise Scheduler
    """

    def __init__(self, config):
        self.scheduler = APIScheduler(config)
        self.infer = InferScheduler(config)

    def start(self, role, host, disaggregated):
        """
            Start APIScheduler and InferScheduler backup threads
        """
        logger.info(
            f"Scheduler Start With: role:{role}, host:{host}, disaggregated:{disaggregated}"
        )
        self.infer.start(role, host, disaggregated)
        self.scheduler.start()

    def reset_nodeid(self, nodeid):
        """
            reset node id
        """
        self.scheduler.nodeid = nodeid
        self.infer.nodeid = nodeid

    def put_requests(self, reqs: List[Request]):
        """
            put requests to global splitwise scheduler
        """
        return self.scheduler.put_requests(reqs)

    def get_results(self, request_ids=[]):
        """
            get results from global splitwise scheduler
        """
        return self.scheduler.get_results()

    def get_requests(self,
                     available_blocks,
                     block_size,
                     reserved_output_blocks,
                     max_num_batched_tokens,
                     batch=1):
        """
            get scheduled requests from global spltiwise scheduler
        """
        if available_blocks <= reserved_output_blocks or batch < 1:
            logger.info(
                f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
                f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
                f"max_num_batched_tokens={max_num_batched_tokens}")
            return []
        return self.infer.get_requests(available_blocks, block_size,
                                       reserved_output_blocks,
                                       max_num_batched_tokens, batch)

    def put_results(self, results: List[RequestOutput]):
        """
            put results to global splitwise scheduler
        """
        return self.infer.put_results(results)


class NodeInfo(object):
    """
        Infer Node Info: identity, role, host, disaggregation info, current
        load, last-report timestamp and the per-request load bookkeeping.
    """

    @classmethod
    def load_from(cls, nodeid, info):
        """
            Build a node from its serialized (JSON) form.

            Args:
                nodeid: id to give the node
                info: JSON payload with the keys "ts", "role", "load",
                    "host" and "disaggregated" (as written by serialize)

            Raises:
                KeyError: if one of the expected keys is missing
        """
        health = orjson.loads(info)
        ts = health["ts"]
        role = health["role"]
        load = int(health["load"])
        host = health["host"]
        disaggregated = health["disaggregated"]
        return cls(nodeid, role, host, disaggregated, load, ts)

    def __init__(self, nodeid, role, host, disaggregated, load, ts=None):
        """
            Args:
                nodeid: node identifier
                role: node role
                host: node host
                disaggregated: disaggregation info, stored as-is
                load: initial load
                ts: last-report timestamp; the current time is used when
                    omitted
        """
        self.nodeid = nodeid
        # Resolve the default per call: a `ts=time.time()` default would be
        # evaluated only once, when the class is defined.
        self.ts = time.time() if ts is None else ts
        self.host = host
        self.disaggregated = disaggregated
        self.role = role
        self.lock = threading.Lock()
        self.load = load
        # req_id -> [load, last-touched timestamp]
        self.reqs = dict()

    def __repr__(self):
        return f"{self.nodeid}({self.load})"

    def expired(self, expire_period):
        """
            APIScheduler used to check if the node is expired, i.e. whether
            more than ``expire_period`` seconds passed since ``self.ts``.
        """
        now = time.time()
        return (now - self.ts) > expire_period

    def serialize(self):
        """
            InferScheduler used to sync load.

            Refreshes ``self.ts`` to the current time and returns the node
            state as JSON bytes.
        """
        self.ts = time.time()
        health = {
            "ts": self.ts,
            "role": self.role,
            "load": self.load,
            "host": self.host,
            "disaggregated": self.disaggregated
        }
        return orjson.dumps(health)

    def __lt__(self, other):
        # Nodes order by load, so sorting a list puts the least loaded first.
        return self.load < other.load

    def expire_reqs(self, ttl):
        """
            InferScheduler used to clear expired reqs.

            Every tracked request not touched for more than ``ttl`` seconds
            is logged, dropped, and its load is subtracted from the node.
        """
        cur_time = time.time()
        with self.lock:
            # Collect first: the dict must not change size while iterated.
            expired = []
            for req_id, (load, arrival_time) in self.reqs.items():
                if cur_time - arrival_time > ttl:
                    logger.error(
                        f"InferScheduler Expire Reqs({req_id}), arrival({arrival_time}), ttl({ttl})"
                    )
                    expired.append((req_id, load))
            for req_id, load in expired:
                self.load -= load
                del self.reqs[req_id]

    def add_req(self, req_id, load):
        """
            InferScheduler used to record scheduled reqs(waiting or running).

            A request id that is already tracked is ignored, so its load is
            only counted once.
        """
        with self.lock:
            if req_id not in self.reqs:
                self.reqs[req_id] = [load, time.time()]
                self.load += load

    def update_req_timestamp(self, req_ids):
        """
            InferScheduler used to update reqs timestamp.

            Unknown request ids are skipped.
        """
        cur_time = time.time()
        with self.lock:
            for req_id in req_ids:
                if req_id in self.reqs:
                    self.reqs[req_id][1] = cur_time

    def finish_req(self, req_id):
        """
            InferScheduler used to clear finished reqs.

            Drops the request and subtracts its load; unknown ids are a no-op.
        """
        with self.lock:
            entry = self.reqs.pop(req_id, None)
            if entry is not None:
                self.load -= entry[0]


class ResultReader(object):
    """
        Polls redis on its own thread for the inference results of the
        requests it tracks, and hands them out grouped by request id.
    """

    def __init__(self, client, idx, batch=200, ttl=900, group=""):
        self.client = client
        self.idx = idx
        self.batch = batch
        self.ttl = ttl
        self.group = group

        # results fetched by the polling thread, waiting for read()
        self.data = deque()
        # req_id -> {"arrival_time": ...} for requests still being tracked
        self.reqs = dict()
        # req_id -> results held back until the send_idx == 0 result shows up
        self.out_buffer = dict()
        self.lock = threading.Lock()

        # the polling thread is launched as part of construction
        self.thread = threading.Thread(target=self.run)
        self.thread.start()

    def add_req(self, req):
        """
            Start tracking a request; its results are then fetched from redis
            by the polling thread.
        """
        with self.lock:
            req_id = req.request_id
            self.reqs[req_id] = {"arrival_time": req.arrival_time}
            self.out_buffer[req_id] = []

    def read(self):
        """
            Drain everything fetched so far.
            returns: dict(req_id, [ResultOutput])
        """
        # Take exactly what is queued right now, oldest first.
        pending = [self.data.pop() for _ in range(len(self.data))]

        outputs = {}
        later_tokens = {}
        done = set()
        for item in pending:
            req_id = item.request_id
            failed = item.error_code != 200

            if failed or item.finished:
                done.add(req_id)

            if failed or item.outputs.send_idx == 0:
                # an error or the first result starts the request's list
                outputs[req_id] = [item]
            else:
                later_tokens.setdefault(req_id, []).append(item)

        with self.lock:
            for req_id in done:
                self.reqs.pop(req_id, None)

            # Release whatever was held back for requests whose first result
            # (or an error) arrived in this round.
            for req_id, collected in outputs.items():
                held = self.out_buffer.pop(req_id, None)
                if held is not None:
                    collected.extend(held)

            # Remaining results are held back while the request still has a
            # buffer, and returned directly otherwise.
            for req_id, tokens in later_tokens.items():
                held = self.out_buffer.get(req_id)
                if held is not None:
                    held.extend(tokens)
                else:
                    outputs.setdefault(req_id, []).extend(tokens)

        return outputs

    def run(self):
        """
            Polling loop: expire stale requests, then fetch results from redis.
        """
        while True:
            try:
                now = time.time()
                live_keys = []
                with self.lock:
                    timed_out = []
                    for req_id, meta in self.reqs.items():
                        if now - meta.get("arrival_time", now) > self.ttl:
                            # Emit an error result in place of the real one.
                            failure = RequestOutput(
                                request_id=req_id,
                                prompt="",
                                prompt_token_ids=[],
                                outputs=CompletionOutput(-1, -1, []),
                                metrics=RequestMetrics(
                                    arrival_time=meta["arrival_time"]),
                                error_code=500,
                                error_msg=f"Req({req_id}) is expired({self.ttl})")
                            self.data.appendleft(failure)

                            logger.error(
                                f"Req({req_id}) is expired({self.ttl})")
                            timed_out.append(req_id)
                        else:
                            live_keys.append(req_id)
                    for req_id in timed_out:
                        del self.reqs[req_id]

                if not live_keys:
                    time.sleep(0.01)
                    continue

                # back off briefly when redis had nothing for us
                if self.sync_results(live_keys) == 0:
                    time.sleep(0.01)
            except Exception as e:
                logger.error(
                    f"ResultsReader{self.idx} sync results error: {str(e)}")

    def sync_results(self, keys):
        """
            Fetch results from redis for the given keys and queue them.
            When a group is configured, only the group key is read.
            returns: number of raw results popped from redis
        """
        if self.group != "":
            keys = [self.group]

        fetched = 0
        for key in keys:
            batch = self.client.rpop(key, self.batch)
            if not batch:
                continue
            fetched += len(batch)
            for entry in batch:
                try:
                    entry = RequestOutput.from_dict(orjson.loads(entry))
                    self.data.appendleft(entry)
                except Exception as e:
                    logger.error(f"Parse Result Error:{e}, {entry}")
        return fetched


class APIScheduler(object):
    """
        APIScheduler: queues incoming requests locally, pushes each one to the
        redis request queue of a chosen inference node, and collects the
        inference results through a pool of ResultReader threads.
    """

    def __init__(self, config):
        self.nodeid = config.nodeid
        self.reader_parallel = config.reader_parallel
        self.reader_batch_size = config.reader_batch_size
        self.expire_period = config.expire_period
        self.clear_expired_nodes_period = config.clear_expired_nodes_period
        self.ttl = config.ttl
        self.topic = config.redis_topic
        # redis hash holding one serialized NodeInfo per node id
        self.cluster_key = f"{self.topic}.cluster"

        self.client = redis.Redis(host=config.redis_host,
                                  port=config.redis_port,
                                  password=config.redis_password)

        self.req_cond = threading.Condition()
        self.reqs_queue = deque()
        self.readers = []

    def start(self):
        """
            Start the background threads: the result readers, the expired
            node cleaner and the scheduling loop.
        """
        for i in range(self.reader_parallel):
            # each reader gets its own result key: "<nodeid>-<reader index>"
            group = f"{self.nodeid}-{i}"
            reader = ResultReader(self.client, i, self.reader_batch_size,
                                  self.ttl, group)
            self.readers.append(reader)

        self.clear_expired_nodes_thread = threading.Thread(
            target=self.loop_clear_expired_nodes)
        self.clear_expired_nodes_thread.start()

        self.schedule_thread = threading.Thread(target=self.loop_schedule)
        self.schedule_thread.start()

    def put_requests(self, reqs):
        """
            Put requests to the local req queue; they are scheduled
            asynchronously by the scheduling loop.

            Returns:
                list of ``(request_id, None)`` tuples, one per request
        """
        ret = []
        with self.req_cond:
            for req in reqs:
                self.reqs_queue.appendleft(req)
                ret.append((req.request_id, None))
            self.req_cond.notify_all()
        return ret

    def get_results(self):
        """
            Get the infer results the readers have fetched from redis so far.

            Returns:
                dict mapping request id to its list of results, merged over
                all readers
        """
        outputs = dict()
        for reader in self.readers:
            outputs.update(reader.read())
        return outputs

    def loop_schedule(self):
        """
            Loop: wait for a queued request, refresh the cluster load states
            and schedule the request onto a node.
        """
        reader_idx = 0
        while True:
            try:
                with self.req_cond:
                    while len(self.reqs_queue) == 0:
                        self.req_cond.wait()

                pnodes, dnodes, mnodes = self.sync_cluster()
                # Schedulable means: a mixed node, or a prefill + decode pair.
                if len(mnodes) == 0 and (len(pnodes) == 0 or len(dnodes) == 0):
                    logger.error(
                        f"No Schedule Nodes: mixed:{len(mnodes)}, prefill:{len(pnodes)}, decode:{len(dnodes)}"
                    )
                    time.sleep(1)
                    continue

                try:
                    req = self.reqs_queue.pop()
                except IndexError:
                    # Queue is empty after all; go back to waiting. Only this
                    # pop is guarded, so an IndexError raised while
                    # scheduling is reported below instead of being dropped.
                    continue

                # Readers are assigned round-robin.
                reader = self.readers[reader_idx]
                reader.add_req(req)
                group = reader.group
                reader_idx = (reader_idx + 1) % len(self.readers)

                self.schedule(req, pnodes, dnodes, mnodes, group)
            except Exception as e:
                logger.error(f"APIScheduler Schedule req error: {str(e)}")

    def schedule(self, req, pnodes, dnodes, mnodes, group=""):
        """
            Schedule a req to the redis queue(s) of the selected node(s).

            A mixed node receives the request alone. A prefill node is paired
            with a decode node and the same serialized request is pushed to
            both queues (decode first).

            Args:
                req: the request to schedule
                pnodes / dnodes / mnodes: prefill, decode and mixed NodeInfo
                    lists; they are not modified
                group: result key recorded in the serialized request
        """
        # A prefill node is useless without a decode node to pair it with, so
        # prefill nodes only compete with mixed nodes when a decode node
        # exists. Otherwise picking one would fail on the empty decode list.
        candidates = list(pnodes) if dnodes else []
        candidates.extend(mnodes)
        candidates.sort()
        pnode = self.select_pd(req, candidates, "prefill")

        if pnode.role == "mixed":
            req.disaggregate_info = None
            req_dict = req.to_dict()
            req_dict["group"] = group
            req_str = orjson.dumps(req_dict)
            self.client.lpush(f"ReqQ_{pnode.nodeid}", req_str)
            return

        dnode = self.select_pd(req, sorted(dnodes), "decode")
        # Copy so the chosen protocol does not leak into the NodeInfo.
        disaggregated = copy.deepcopy(dnode.disaggregated)
        transfer_protocol = disaggregated["transfer_protocol"]
        if len(
                transfer_protocol
        ) > 1 and "ipc" in transfer_protocol and "rdma" in transfer_protocol:
            # Both offered: ipc when prefill and decode share a host.
            if pnode.host == dnode.host:
                disaggregated["transfer_protocol"] = "ipc"
            else:
                disaggregated["transfer_protocol"] = "rdma"
        else:
            disaggregated["transfer_protocol"] = transfer_protocol[0]
        req.disaggregate_info = disaggregated

        pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}"
        req_dict = req.to_dict()
        req_dict["group"] = group
        req_str = orjson.dumps(req_dict)
        self.client.lpush(dkey, req_str)
        self.client.lpush(pkey, req_str)

    def sync_cluster(self):
        """
            Fetch the cluster load states from redis.

            Returns:
                ``(pnodes, dnodes, mnodes)``: the non-expired prefill, decode
                and mixed nodes. Expired nodes and unknown roles are logged
                and left out.
        """
        clusters = self.client.hgetall(self.cluster_key)
        pnodes, dnodes, mnodes = [], [], []
        by_role = {"prefill": pnodes, "decode": dnodes, "mixed": mnodes}
        for nodeid, info in clusters.items():
            node = NodeInfo.load_from(nodeid.decode(), info)
            if node.expired(self.expire_period):
                logger.error(f"node {nodeid} is expired: {info}")
                continue
            bucket = by_role.get(node.role)
            if bucket is None:
                logger.error(f"Invalid Role: {node.role} {info}")
                continue
            bucket.append(node)
        return pnodes, dnodes, mnodes

    def loop_clear_expired_nodes(self):
        """
            Loop: remove from the redis cluster hash every node whose last
            report is older than ``clear_expired_nodes_period``.
        """
        while True:
            try:
                expire_nodes = set()
                clusters = self.client.hgetall(self.cluster_key)
                for nodeid, info in clusters.items():
                    node = NodeInfo.load_from(nodeid.decode(), info)
                    if node.expired(self.clear_expired_nodes_period):
                        expire_nodes.add(nodeid)
                for nodeid in expire_nodes:
                    self.client.hdel(self.cluster_key, nodeid)
                time.sleep(self.clear_expired_nodes_period)
            except Exception as e:
                logger.error(
                    f"APIScheduler clear expired nodes error: {str(e)}")

    def select_pd(self, req, nodes, role):
        """
            Select a prefill/decode/mixed node based on load states.

            Args:
                req: request being scheduled
                nodes: candidate nodes, sorted by ascending load
                role: "prefill" / "mixed" or "decode"; decides how wide the
                    load tolerance is

            Returns:
                a node picked at random among those whose load is within the
                tolerance of the least loaded node

            Raises:
                Exception: for any other role
                IndexError: if ``nodes`` is empty (for a valid role)
        """

        def select(req, nodes, blur_step):
            # Nodes are sorted, so those with load below min_load + blur_step
            # form a prefix; the least loaded node is always part of it.
            min_load = nodes[0].load
            blur_max = min_load + blur_step
            blur_idx = 0
            for idx, node in enumerate(nodes):
                if node.load >= blur_max:
                    break
                blur_idx = idx
            node = random.choice(nodes[:blur_idx + 1])
            logger.info(
                f"Schedule Req {req.request_id}(len:{req.prompt_token_ids_len}) to {node}"
            )
            return node

        if role == "prefill" or role == "mixed":
            # Tolerance scales with the prompt length, clamped to [100, 500].
            size = req.prompt_token_ids_len
            rate = 2 if size < 1000 else 10
            pblur_step = max(100, min(500, int(size / rate)))
            return select(req, nodes, pblur_step)
        elif role == "decode":
            dblur_step = min(len(nodes), 10)
            return select(req, nodes, dblur_step)

        raise Exception(f"Invalid Role: {role}")


class ResultWriter(object):
    """
        ResultWriter uses a background thread to batch-write infer results to
        redis.
    """

    def __init__(self, client, idx, batch, ttl=900):
        """
            Args:
                client: redis client used for the writes
                idx: index of this writer
                batch: maximum number of queued items written per round
                ttl: expiry applied to every written key (rounded up to an
                    integer before being sent to redis)
        """
        self.idx = idx
        self.batch = batch
        self.client = client
        # (key, item) pairs waiting to be written
        self.data = deque()
        self.cond = threading.Condition()
        # created here, launched by start()
        self.thread = threading.Thread(target=self.run)
        self.ttl = ttl

    def start(self):
        """Start the background writer thread."""
        self.thread.start()

    def put(self, key, items):
        """
            Queue infer results for ``key`` and wake the writer thread.
        """
        with self.cond:
            for item in items:
                self.data.appendleft((key, item))
            self.cond.notify_all()

    def run(self):
        """
            Writer loop: take up to ``batch`` queued items, group them by key
            and push each group to redis together with the key expiry.
        """
        while True:
            try:
                with self.cond:
                    # Wait on the predicate and size the batch afterwards, so
                    # a wake-up is followed by real work instead of an empty
                    # pass made with a size read before the wait.
                    while not self.data:
                        self.cond.wait()
                    size = min(len(self.data), self.batch)

                groups = dict()
                for _ in range(size):
                    key, item = self.data.pop()
                    groups.setdefault(key, []).append(item)

                for key, items in groups.items():
                    # push and expiry are sent as one transaction
                    with self.client.pipeline() as pipe:
                        pipe.multi()
                        pipe.lpush(key, *items)
                        pipe.expire(key, math.ceil(self.ttl))
                        pipe.execute()
            except Exception as e:
                logger.error(f"ResultWriter write error: {str(e)}")


class InferScheduler(object):
    """
        InferScheduler: pulls the requests scheduled onto this node from redis
        into a local queue, and writes the infer results back to redis.
    """

    def __init__(self, config):
        self.nodeid = config.nodeid
        self.writer_parallel = config.writer_parallel
        self.writer_batch_size = config.writer_batch_size
        self.sync_period = config.sync_period
        self.topic = config.redis_topic
        # redis hash this node reports its state into
        self.cluster_key = f"{self.topic}.cluster"
        self.ttl = config.ttl
        self.release_load_expire_period = config.release_load_expire_period

        self.client = redis.Redis(host=config.redis_host,
                                  port=config.redis_port,
                                  password=config.redis_password)

        self.reqs_queue = deque()
        self.writers = []

    def start(self, role, host, disaggregated):
        """
            Start the background threads: result writers, request fetcher,
            node reporter and expired request cleaner.

            Args:
                role: role of this node
                host: host of this node
                disaggregated: disaggregation info reported for this node
        """
        # Set the state the worker threads read BEFORE starting any of them;
        # the request fetcher uses self.role and self.node right away.
        self.role = role
        self.host = host
        self.node = NodeInfo(self.nodeid, role, host, disaggregated, 0)

        for i in range(self.writer_parallel):
            writer = ResultWriter(self.client, i, self.writer_batch_size,
                                  self.ttl)
            writer.start()
            self.writers.append(writer)

        self.getreq_thread = threading.Thread(target=self.loop_get_reqs)
        self.getreq_thread.start()

        self.report_thread = threading.Thread(target=self.routine_report)
        self.report_thread.start()

        self.expire_reqs_thread = threading.Thread(
            target=self.loop_expire_reqs)
        self.expire_reqs_thread.start()

    def routine_report(self):
        """
            Routine report node info (load, health) to the redis cluster hash,
            every ``sync_period`` milliseconds.
        """
        while True:
            try:
                info = self.node.serialize()
                self.client.hset(self.cluster_key, self.nodeid, info)
                time.sleep(self.sync_period / 1000.)
            except Exception as e:
                logger.error(f"InferScheduler routine report error: {str(e)}")

    def loop_expire_reqs(self):
        """
            Loop: once a minute, release the load of requests that were not
            touched for ``release_load_expire_period`` seconds.
        """
        while True:
            try:
                self.node.expire_reqs(self.release_load_expire_period)
                time.sleep(60)
            except Exception as e:
                logger.error(f"InferScheduler expire reqs error: {str(e)}")

    def loop_get_reqs(self):
        """
            Loop: move the requests scheduled onto this node from its redis
            queue to the local queue and account for their load.
        """

        def select_writer(req):
            # Hash of the request id, so a request always maps to one writer.
            req_id = req.request_id
            md5 = hashlib.md5()
            md5.update(req_id.encode())
            writer_idx = int(md5.hexdigest(), 16) % len(self.writers)
            return writer_idx

        batch = 50
        while True:
            try:
                key = f"ReqQ_{self.nodeid}"
                reqs = self.client.rpop(key, batch)
                if reqs is None:
                    # Nothing queued: block up to a second for one request.
                    ret = self.client.brpop([key], timeout=1)
                    if ret is None:
                        continue
                    reqs = [ret[1]]

                for req_str in reqs:
                    req = orjson.loads(req_str)
                    group = req.get("group", "")
                    req = Request.from_dict(req)
                    writer_idx = select_writer(req)
                    logger.info(
                        f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}"
                    )
                    # Writer index and group travel inside the request id and
                    # are split off again in put_results.
                    req.request_id = f"{req.request_id}#{writer_idx}#{group}"
                    if self.role == "prefill" or self.role == "mixed":
                        self.reqs_queue.append(req)
                        self.node.add_req(req.request_id,
                                          req.prompt_token_ids_len)
                    else:
                        # Other roles only account for the load (1 per
                        # request); nothing is queued locally.
                        self.node.add_req(req.request_id, 1)
            except Exception as e:
                logger.error(f"InferScheduler loop get reqs error: {str(e)}")

    def get_requests(self, available_blocks, block_size,
                     reserved_output_blocks, max_num_batched_tokens, batch):
        """
            Get scheduled reqs from the local reqs queue.

            Takes up to ``batch`` requests from the head of the queue, in
            order, stopping at the first one that would exceed
            ``available_blocks`` or ``max_num_batched_tokens`` (that request
            is put back at the head). Requests older than ``ttl`` are dropped
            and their load is released.

            Returns:
                list of the requests taken (possibly empty)
        """
        if len(self.reqs_queue) == 0:
            return []

        reqs = []
        required_blocks = 0
        current_prefill_tokens = 0
        cur_time = time.time()
        for _ in range(batch):
            try:
                req = self.reqs_queue.popleft()
            except IndexError:
                # local queue drained
                break
            try:
                if cur_time - req.arrival_time > self.ttl:
                    logger.error(
                        f"req({req.request_id}) is expired({self.ttl}) when InferScheduler Get Requests"
                    )
                    self.node.finish_req(req.request_id)
                    continue
                current_prefill_tokens += req.prompt_token_ids_len
                # ceil(prompt length / block_size)
                required_input_blocks = (req.prompt_token_ids_len +
                                         block_size - 1) // block_size
                required_blocks += required_input_blocks + reserved_output_blocks
                if required_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens:
                    self.reqs_queue.appendleft(req)
                    return reqs
                reqs.append(req)
            except Exception as e:
                # Same outcome as before (return what was collected), but the
                # failure is now reported instead of being swallowed.
                logger.error(
                    f"InferScheduler get requests error: {str(e)}")
                return reqs
        return reqs

    def put_results(self, results):
        """
            Put infer results to the local queue of the matching writer.

            Each result's request id is expected in the
            "<request id>#<writer index>#<group>" form built by
            loop_get_reqs; it is restored to the bare request id before the
            result is serialized. Results are written under the group key, or
            under the request id when the group is empty.
        """
        groups = dict()
        req_ids = set()
        for result in results:
            if result.error_code != 200 or result.finished:
                self.node.finish_req(result.request_id)
                logger.info(
                    f"{result.request_id} finished, node load is {self.node.load}"
                )

            # The node tracks requests under the full (suffixed) id.
            req_ids.add(result.request_id)

            req_id, idx, group = result.request_id.split("#")
            result.request_id = req_id

            key = (req_id if group == "" else group, int(idx))

            # On a prefill node the send_idx == 0 result never finishes the
            # request.
            if self.role == "prefill" and result.outputs.send_idx == 0:
                result.finished = False

            result_str = orjson.dumps(result.to_dict())
            groups.setdefault(key, list()).append(result_str)

        self.node.update_req_timestamp(req_ids)

        for (redis_key, idx), outputs in groups.items():
            self.writers[idx].put(redis_key, outputs)
